Introduction

The goal of this project is to develop a model that can detect if a mushroom is poisonous based on its characteristics.

What are mushrooms?

Mushrooms are members of the Fungi kingdom, which includes over 144,000 species such as molds, yeasts, rusts, and many more. Some species of mushroom are edible and safe to eat, while others are poisonous and can cause serious harm when consumed. Mushrooms are commonly found alongside plants, and often lean on them for support as they grow. Although mushrooms are sometimes mistaken for plants because they are edible, they do not belong to the Plantae kingdom. The body of a mushroom is made up of several distinct features, such as a stalk, commonly known as a stem, and a disc-shaped cap. Beneath the cap of many mushroom species is a series of closely spaced slits called gills, which often appear on the edible mushrooms found in supermarkets. In some species, however, this region is covered by pores instead. In addition to these physical characteristics, mushrooms can be identified by their odor, color, and even habitat.

Why might this model be useful?

Certain mushrooms, such as some species in the Amanita and Cortinarius genera, are extremely dangerous and can be fatal when eaten. The Amanita phalloides mushroom, also known as the “death cap”, is perhaps the most deadly mushroom species. The frightening thing about this species is that it very closely resembles edible straw mushrooms. Imagine you were starving on a remote island and the only source of food around you was different types of mushroom. Wouldn’t you like a model to help predict which of those mushrooms were safe to eat?

Loading Data and Packages

This project uses a Kaggle data set named “Mushroom Classification Updated Dataset” that includes hypothetical samples pertaining to twenty-three different gilled mushroom species. These species belong to the Agaricus and Lepiota family of mushrooms and were taken from The Audubon Society Field Guide to North American Mushrooms (1981). Each data point is labeled as either edible or poisonous and contains information regarding the mushroom’s physical characteristics.

Although I have provided a full .csv file of the codebook in the GitHub repository, I have listed below the key variables used in this report:

  • class : Mushroom class; either ‘edible’ or ‘poisonous’

  • cap_shape : The shape of the mushroom’s cap

  • bruises : Whether or not the mushroom has bruises

  • odor : The mushroom’s odor

  • gill_size : The size of the mushroom’s gills

  • gill_color : The color of the mushroom’s gills

  • stalk_root : The type of root on the mushroom

  • stalk_color_above_ring : The color of the stalk above the ring of the mushroom

  • stalk_color_below_ring : The color of the stalk below the ring of the mushroom

  • spore_print_color : The color of the mushroom’s spore print

  • population : The density of mushrooms in the location the mushroom was found

  • habitat : The type of environment where the mushroom was found

library(ranger)
library(janitor)
library(rpart.plot)
library(ggplot2)
library(tidyverse)
library(tidymodels)
library(tibble)
library(corrplot)
library(yardstick)
library(corrr)
library(pROC)
library(glmnet)
library(ggthemes)
library(vip)
library(xgboost)
library(kknn)
library(psych)
library(dplyr)
library(knitr)
library(haven)
library(sjlabelled) # package to read and write item labels and values
library(lubridate, warn.conflicts = FALSE)
tidymodels_prefer()
shrooms <- read_csv("/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/mushroomsupdated copy.csv")

Data Cleaning

Before performing a split of the data, a couple of cleaning steps must be taken to organize the data.

I will first clean up the names of the variables in the data set using the clean_names() function.

shrooms <- shrooms %>% 
  clean_names()

Next, I will get rid of several unwanted variables.

shrooms <- shrooms %>% 
  select(-cap_surface, -cap_color, -gill_attachment, -gill_spacing, -stalk_surface_above_ring, -stalk_surface_below_ring, -veil_color, -ring_number, -ring_type)

Data Split

In splitting the data, I will select 80% for training and leave the remaining 20% for testing. To account for a skewed distribution in the outcome variable, I will use stratified sampling.

I have chosen to split the data before conducting the exploratory data analysis to make sure I am not influenced by the information given in the testing set.

shrooms_split <- shrooms %>% 
  initial_split(prop = 0.8, strata = 'class')

shrooms_train <- training(shrooms_split)
shrooms_test <- testing(shrooms_split)

# check that the data was split correctly
nrow(shrooms_train)
nrow(shrooms_test)
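
Beyond the row counts, it can help to confirm that stratification kept the class balance similar in both partitions. The following is a minimal sketch on a made-up toy data frame rather than the real mushroom data; the seed value and the `toy` objects are hypothetical and only serve the illustration.

```r
library(rsample)

set.seed(131)  # hypothetical seed, used only so the sketch is reproducible

# Toy stand-in for the mushroom data: 70 edible and 30 poisonous rows
toy <- data.frame(
  class = rep(c("edible", "poisonous"), times = c(70, 30)),
  odor  = sample(c("none", "foul", "almond"), 100, replace = TRUE)
)

# Same call pattern as above: 80/20 split, stratified on the outcome
toy_split <- initial_split(toy, prop = 0.8, strata = class)
toy_train <- training(toy_split)
toy_test  <- testing(toy_split)

# With stratification, the class shares should be close in both sets
train_props <- prop.table(table(toy_train$class))
test_props  <- prop.table(table(toy_test$class))
print(round(train_props, 2))
print(round(test_props, 2))
```

Setting a seed before initial_split() would also make the split itself reproducible from one run to the next.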

Exploratory Data Analysis

This section of my report will focus only on the training set, which contains 6,498 observations. Each observation represents a unique mushroom.

Odor

An unpleasant odor, like a white spore print, is often a strong indication that a mushroom is poisonous. For this reason, I will begin my EDA by examining the distribution of poisonous mushrooms contained in each odor class.

ggplot(shrooms_train, aes(odor)) +
  geom_bar(fill = c('lightgreen')) +
  facet_wrap(~class, scales = "free_y") +
  labs(title = "Distribution of Odor by Class", x = "Odor", y = "Count") +
  coord_flip()

ggplot(shrooms_train, aes(fill = odor, x = class)) +
  geom_bar(stat="count", color = "black") +
  facet_wrap(~odor, scales = "free_y") +
  labs(title = "Histogram of Class by Odor Types") + 
  theme(legend.position = 'none')

The histogram for the distribution of odor by class demonstrates that the most common smell for a poisonous mushroom is foul, while the most common smell for an edible mushroom is none. The histogram of class by odor types clearly depicts a strong association between the two variables. Specifically, we can see that edible mushrooms never seem to have a foul, fishy, spicy, pungent, or creosote smell, meaning these odors are a strong indication of a poisonous mushroom. On the other hand, poisonous mushrooms never seem to have an anise or almond smell, which means these two odors are a strong indication of an edible mushroom. The most misleading odor type appears to be ‘none’, since both edible and poisonous mushrooms can have no smell.
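
The same comparison can be read off a cross-tabulation instead of a plot. This is a small sketch on made-up rows, not the real training set: it computes, for each odor, the share of each class, and then lists the odors that occur in only one class.

```r
# Toy stand-in for shrooms_train; the values are made up for illustration
toy <- data.frame(
  class = c("poisonous", "poisonous", "poisonous", "edible",
            "edible", "edible", "edible", "poisonous"),
  odor  = c("foul", "foul", "none", "none",
            "none", "almond", "almond", "pungent")
)

# Counts of each class within each odor level
counts <- table(toy$odor, toy$class)

# Row-wise proportions: within each odor, the share of each class
shares <- prop.table(counts, margin = 1)
print(round(shares, 2))

# Odors that appear in only one class in this sample
pure_odors <- rownames(shares)[apply(shares, 1, max) == 1]
print(pure_odors)
```

On the real data, the same two calls with shrooms_train$odor and shrooms_train$class would give the exact shares behind the bar heights.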

Spore Print Color

Similarly to odor type, the spore print color of a mushroom is often a great indication of whether or not a mushroom is edible. Let’s take a look at the distribution of poisonous mushrooms in each spore print color type.

ggplot(shrooms_train, aes(spore_print_color)) +
  geom_bar(fill = c('lightgreen')) +
  facet_wrap(~class, scales = "free_y") +
  labs(title = "Distribution of Spore Print Color by Class", x = "Spore Print Color", y = "Count") +
  coord_flip()

ggplot(shrooms_train, aes(fill = spore_print_color, x = class)) +
  geom_bar(stat="count", color = "black") +
  facet_wrap(~spore_print_color, scales = "free_y") +
  labs(title = "Histogram of Class by Spore Print Colors") + 
  theme(legend.position = 'none')

We can see here that the most common spore print colors for edible mushrooms are brown and black, while the most common spore print colors for poisonous mushrooms are white and chocolate. The histogram for the distribution of class by spore print color helps support our prior assumption that spore print color is a strong indicator of class type, demonstrating in particular that chocolate, green, and white spore print colors are strongly associated with poisonous mushrooms.

Stalk Shape and Veil Type

Let’s take a look at the relationship between class type and the variables stalk shape and veil type.

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", fill = c('lightyellow')) +
  facet_wrap(~stalk_shape, scales = "free_y") +
  labs(title = "Histogram of Class by Stalk Shape")

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", fill = c('lightyellow')) +
  facet_wrap(~veil_type, scales = "free_y") +
  labs(title = "Histogram of Class by Veil Type")

The histograms of class by stalk shape reveal a nearly even distribution between edible mushrooms and poisonous mushrooms, telling us that the stalk shape variable is a poor indicator of class type. We can also see from the histogram of class by veil type that there exists only one veil type within our data set. This means that the veil type variable cannot contribute to predicting our outcome variable.

Therefore, we will choose to remove these variables from our analysis.

shrooms_train <- shrooms_train %>% 
  select(-stalk_shape, -veil_type)
shrooms_test <- shrooms_test %>% 
  select(-stalk_shape, -veil_type)
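
A constant column such as veil_type can also be detected programmatically rather than by eye. The sketch below uses a made-up toy data frame and counts the distinct values in each column; any column with a single distinct value carries no information about class. The recipes package offers step_zv() for the same purpose inside a recipe.

```r
# Toy stand-in: veil_type takes a single value, as in the plot above
toy <- data.frame(
  class     = c("edible", "poisonous", "edible", "poisonous"),
  veil_type = c("partial", "partial", "partial", "partial"),
  odor      = c("none", "foul", "almond", "foul")
)

# Number of distinct values per column
n_levels <- sapply(toy, function(col) length(unique(col)))
print(n_levels)

# Columns with exactly one distinct value are constant
constant_cols <- names(n_levels)[n_levels == 1]
print(constant_cols)
```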

Stalk Color Above Ring and Stalk Color Below Ring

ggplot(shrooms_train, aes(class, stalk_color_above_ring)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") + 
  ylab("Stalk Color Above Ring") +
  labs(title = "Box Plot of Class by Stalk Color Above Ring")

ggplot(shrooms_train, aes(class, stalk_color_below_ring)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") + 
  ylab("Stalk Color Below Ring") +
  labs(title = "Box Plot of Class by Stalk Color Below Ring")

Notice how the two box plots demonstrate an almost identical distribution. Both stalk_color_above_ring and stalk_color_below_ring demonstrate that an orange, gray, or red stalk color is a very strong characteristic of edible mushrooms, while a buff stalk color is a strong characteristic of poisonous mushrooms. We can also see that a white stalk color is very evenly distributed between both classes.

Gill Color

ggplot(shrooms_train, aes(class, gill_color)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") +
  ylab("Gill Color") +
  labs(title = "Box Plot of Class by Gill Color")

ggplot(shrooms_train, aes(class, gill_size)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") +
  ylab("Gill Size") +
  labs(title = "Box Plot of Class by Gill Size")

Looking at a box plot of class by gill color, we see a strong relationship between a buff-colored gill and poisonous mushrooms. We can also note that mushrooms with a brown or black gill color have a higher likelihood of being in the edible class, while mushrooms with a gray or chocolate gill color have a higher likelihood of being in the poisonous class. The data also demonstrate that a pink gill color is split nearly evenly between the classes and will thus have a weak impact on our prediction. Lastly, the box plot of class by gill size tells us that edible mushrooms often have a broad gill size, while poisonous mushrooms often have a narrow gill size.

Habitat

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", color = "black") +
  facet_wrap(~habitat, scales = "free_y") +
  labs(title = "Histogram of Class by Habitat")

ggplot(shrooms_train, aes(class, habitat)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") +
  ylab("Habitat") +
  labs(title = "Box Plot of Class by Habitat")

Although poisonous mushrooms are most commonly found in wooded habitats, this characteristic is not a strong identifier, since an even greater number of edible mushrooms are also found there. On the other hand, most of the mushrooms found on paths are poisonous, even though few mushrooms are found there overall, so a path habitat is a good indicator of a poisonous mushroom! We see a similar, though less prominent, pattern for leafy habitats. Lastly, no poisonous mushrooms were found in waste habitats, and very few were found in meadows.

Population

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", color = "black") +
  facet_wrap(~population, scales = "free_y") +
  labs(title = "Histogram of Class by Population")

ggplot(shrooms_train, aes(class, population)) +
  geom_boxplot() + 
  geom_jitter(alpha = 0.3) +
  xlab("Class") +
  ylab("Population") +
  labs(title = "Box Plot of Class by Population")

These class by population graphs demonstrate that poisonous mushrooms are rarely clustered together in groups and are more often spread out and scattered. Edible mushrooms are much more evenly distributed across population types, with slightly more of them found in scattered and isolated populations.

Remaining Predictors

Let’s now take a look at the distribution of class among our remaining predictors to try and determine which factors are the most important identifiers of poisonous mushrooms.

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", fill = c('lightpink')) +
  facet_wrap(~cap_shape, scales = "free_y") +
  labs(title = "Histogram of Class by Cap Shape ")

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", fill = c('lightpink')) +
  facet_wrap(~bruises, scales = "free_y") +
  labs(title = "Histogram of Class by Bruises")  

ggplot(shrooms_train, aes(class)) +
  geom_bar(stat="count", fill = c('lightpink')) +
  facet_wrap(~stalk_root, scales = "free_y") +
  labs(title = "Histogram of Class by Stalk Root")

The histograms for these variables depict a much more even distribution of class than the other variables. Therefore, we can assume that cap shape, bruises, and stalk root will have less of an impact in determining our outcome variable.

Correlation Plot

In order to construct a correlation plot, we must first convert the character data in the training set into numerical data.

# Convert each character variable into a factor variable
shrooms_train_factor <- shrooms_train
shrooms_train_factor[] <- lapply(shrooms_train_factor, factor)

# Convert each factor variable into a numeric variable
shrooms_train_numeric <- shrooms_train_factor
shrooms_train_numeric[] <- as.data.frame(sapply(shrooms_train_numeric, as.numeric))

To use in a later analysis, I will perform the same conversion on the testing set.

# Convert each character variable into a factor variable
shrooms_test_factor <- shrooms_test
shrooms_test_factor[] <- lapply(shrooms_test_factor, factor)

# Convert each factor variable into a numeric variable
shrooms_test_numeric <- shrooms_test_factor
shrooms_test_numeric[] <- as.data.frame(sapply(shrooms_test_numeric, as.numeric))
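
One caveat with converting the two sets separately: factor() assigns integer codes from the levels present in each data frame, so the same category can receive different codes in training and testing if a level is missing from one of them. The sketch below, on made-up values, shows the mismatch and one possible way to avoid it by reusing the training levels. The object names are hypothetical.

```r
# Toy columns: "almond" appears in the training data only
train <- data.frame(odor = c("almond", "foul", "none", "foul"))
test  <- data.frame(odor = c("none", "foul"))

# Separate conversion: codes depend on the levels each set happens to contain
separate_train <- as.numeric(factor(train$odor))  # almond = 1, foul = 2, none = 3
separate_test  <- as.numeric(factor(test$odor))   # foul = 1, none = 2

# Shared conversion: take the levels from the training set and reuse them
odor_levels <- levels(factor(train$odor))
train_codes <- as.numeric(factor(train$odor, levels = odor_levels))
test_codes  <- as.numeric(factor(test$odor,  levels = odor_levels))

print(separate_test)  # "none" is coded 2 here
print(test_codes)     # "none" is coded 3, matching the training set
```

With shared levels, a category that appears only in the testing set becomes NA, which at least makes the problem visible instead of silently shifting the codes.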

Now that the data frame is in the correct format, we can construct a correlation plot.

# create a correlation plot
shrooms_train_numeric %>%
  select(where(is.numeric)) %>%
  cor(use = "complete.obs") %>%
  corrplot(method = "number", order = 'FPC', type = "lower", diag = FALSE, tl.cex=0.7) 

The correlation plot depicts a strong positive relationship between spore print color and class. This supports our previous investigation of the variable spore print color. Similarly, there is a strong positive correlation between spore print color and the variables bruises and gill size.
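
Because the integer codes behind this plot come from the alphabetical order of unordered categories, the sign and size of each Pearson correlation depend partly on that arbitrary ordering. As a complementary check, and not something used elsewhere in this report, one could compute Cramér's V, an association measure built on the chi-squared statistic that ignores category order. The helper name cramers_v below is hypothetical, and the inputs are made-up values.

```r
# Cramér's V for two categorical vectors: sqrt(chi2 / (n * (min(r, c) - 1)))
cramers_v <- function(x, y) {
  tab  <- table(x, y)
  chi2 <- suppressWarnings(chisq.test(tab, correct = FALSE)$statistic)
  n    <- sum(tab)
  k    <- min(nrow(tab), ncol(tab)) - 1
  as.numeric(sqrt(chi2 / (n * k)))
}

# Class is fully determined by color: V is 1
color <- c("a", "a", "b", "b", "c", "c")
class <- c("p", "p", "e", "e", "p", "p")
v_dependent <- cramers_v(color, class)

# Class is unrelated to shape: V is 0
shape  <- c("x", "x", "y", "y")
class2 <- c("p", "e", "p", "e")
v_independent <- cramers_v(shape, class2)

print(c(v_dependent, v_independent))
```

V ranges from 0 (no association) to 1 (one variable fully determines the other), so it can be compared across predictors with different numbers of categories.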

Now that we have cleaned and examined the data, we can save the training and testing sets as .rds files to be used later in our analysis.

write_rds(shrooms_train, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_train_processed.rds")

write_rds(shrooms_train_factor, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_train_factor_processed.rds")

write_rds(shrooms_train_numeric, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_train_numeric_processed.rds")

write_rds(shrooms_test, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_test_processed.rds")

write_rds(shrooms_test_factor, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_test_factor_processed.rds")

write_rds(shrooms_test_numeric, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_test_numeric_processed.rds")

Model Building

Let’s fold our training set using v-fold cross-validation, with v = 5 and stratifying on our outcome variable class.

# For classification models:
shrooms_folds <- vfold_cv(shrooms_train, v = 5, strata = class)

# For regression models:
shrooms_folds_numeric <- vfold_cv(shrooms_train_numeric, v = 5, strata = class)
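
To make the idea of stratified folds concrete, here is a hand-rolled sketch in base R on a made-up class vector. It illustrates the principle only; the exact procedure inside vfold_cv() may differ. Within each class, the rows are shuffled and dealt out to the five folds in turn, so every fold ends up with the same class mix.

```r
set.seed(131)  # hypothetical seed, used only so the sketch is reproducible

# Toy outcome: 60 edible and 40 poisonous observations
classes <- rep(c("edible", "poisonous"), times = c(60, 40))
v <- 5

fold_id <- integer(length(classes))
for (cl in unique(classes)) {
  idx <- which(classes == cl)
  idx <- idx[sample.int(length(idx))]              # shuffle rows of this class
  fold_id[idx] <- rep_len(seq_len(v), length(idx)) # deal them out to folds 1..v
}

# Each fold should hold 12 edible and 8 poisonous rows
fold_table <- table(fold_id, classes)
print(fold_table)
```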

Let’s save these objects into a model fitting folder to maintain the same information across each R script model.

save(shrooms_folds, shrooms_train, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/shrooms_model_setup.rda")

save(shrooms_folds_numeric, shrooms_train_numeric, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/shrooms_numeric_model_setup.rda")

In this report I will examine six different models. Of these six models, two are regression models while the other four are classification models. The first two models I will investigate are the following regression models:

  • Ridge Regression
  • Lasso Regression

Regression Models

Ridge Regression Model

In order to create a ridge regression model, we will use linear_reg() and set mixture = 0 to specify a ridge model. For this model we will use glmnet and tune the model to determine the optimal value of penalty by fitting it to the folds. To do this, I will set penalty = tune(). This informs tune_grid() that the penalty parameter must be tuned.

I will start by setting up the recipe, model specification, and workflow for the ridge regression.

ridge_recipe <- recipe(class ~ ., data = shrooms_train_numeric) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_center(all_predictors())

ridge_spec <- linear_reg(penalty = tune(), mixture = 0) %>% 
  set_mode("regression") %>% 
  set_engine("glmnet")

# Now we combine to create a `workflow` object.
ridge_workflow <- workflow() %>% 
  add_recipe(ridge_recipe) %>% 
  add_model(ridge_spec)

Using grid_regular(), I will form a grid of evenly spaced parameter values that will be used in determining the optimal value of penalty.

penalty_grid <- grid_regular(penalty(range = c(-5, 5)), levels = 50)
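
Note that the range passed to penalty() is read on a log10 scale, assuming the default transformation in dials, so c(-5, 5) spans penalties from 0.00001 to 100,000 rather than from -5 to 5. The base R sketch below reproduces what such a grid would contain under that assumption.

```r
# 50 evenly spaced exponents between -5 and 5 ...
exponents <- seq(-5, 5, length.out = 50)

# ... correspond to penalties that are evenly spaced on the log10 scale
penalties <- 10^exponents

print(range(penalties))
print(head(penalties, 3))
```

This is why the tuning plot uses a logarithmic x-axis: neighbouring grid points differ by a constant factor, not a constant amount.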

# Now we have everything we need and we can fit all the models.
tune_res_ridge <- tune_grid(ridge_workflow,resamples = shrooms_folds_numeric, grid = penalty_grid)
## ! Fold1: internal: A correlation computation is required, but `estimate` is const...
## ! Fold2: internal: A correlation computation is required, but `estimate` is const...
## ! Fold3: internal: A correlation computation is required, but `estimate` is const...
## ! Fold4: internal: A correlation computation is required, but `estimate` is const...
## ! Fold5: internal: A correlation computation is required, but `estimate` is const...
save(tune_res_ridge, ridge_workflow, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/tune_res_ridge.rda")

# Use `autoplot()` to create a visualization:
autoplot(tune_res_ridge)

With the graph, we can easily visualize how the amount of regularization affects the performance metrics. Notice how, in some regions, changing the amount of regularization has almost no influence on the metrics.

In order to view the raw metrics used in creating this chart, we can use the function collect_metrics(). We can select the best penalty value by using the function ‘select_best()’ and specifying the metric = “rsq”.
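
The rsq metric in yardstick is documented as the squared correlation between the true values and the predictions. That offers a likely explanation for the fold warnings printed during tuning: at very large penalties the coefficients are shrunk so far that every prediction is the same number, and a correlation with a constant is undefined. A small base R sketch with made-up values shows the effect.

```r
# Made-up outcome codes and two sets of predictions
truth         <- c(1, 2, 1, 2, 2)
varying_pred  <- c(1.1, 1.8, 1.3, 1.9, 1.6)
constant_pred <- rep(mean(truth), 5)  # what a fully shrunk model would predict

# Squared correlation is well defined when predictions vary
rsq_varying <- cor(truth, varying_pred)^2

# With constant predictions the standard deviation is zero, so cor() returns NA
rsq_constant <- suppressWarnings(cor(truth, constant_pred))^2

print(rsq_varying)
print(is.na(rsq_constant))
```

Those penalty values simply yield no usable rsq estimate, so the warning is not by itself a sign that tuning failed.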

# select best rsq value
best_penalty <- select_best(tune_res_ridge, metric = "rsq")

Now that we have determined the optimal penalty value, we can use the finalize_workflow() function to update our workflow. We will replace tune() with our new optimal value of penalty and fit the new model to the entire training set.

ridge_final <- finalize_workflow(ridge_workflow, best_penalty)
ridge_final_fit <- fit(ridge_final, data = shrooms_train_numeric)

To gauge the performance of the final model, I will compute its R-squared on the training set. The testing set is held back for the model that performs best.

shrooms_test_numeric$class <- as.numeric(shrooms_test_numeric$class)
ridge_metric <- augment(ridge_final_fit, new_data = shrooms_train_numeric) %>%
  rsq(truth = class, estimate = .pred)

Lasso Regression Model

In order to create a lasso regression model, I will use linear_reg() and set mixture = 1 to specify a lasso model. For this model we will use glmnet and tune the model to determine the optimal value of penalty by fitting it to the folds. To do this, I will set penalty = tune(). This informs tune_grid() that the penalty parameter must be tuned.

I will start by setting up the recipe, model specification, and workflow for the lasso regression.

lasso_recipe <- 
  recipe(formula = class ~ ., data = shrooms_train_numeric) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_center(all_predictors())

lasso_spec <- 
  linear_reg(penalty = tune(), mixture = 1) %>% 
  set_mode("regression") %>% 
  set_engine("glmnet") 

lasso_workflow <- workflow() %>% 
  add_recipe(lasso_recipe) %>% 
  add_model(lasso_spec)

Using grid_regular(), I will form another grid of evenly spaced parameter values to use in determining the optimal value of penalty. For this grid, I will change the range to be from -3 to 1.

penalty_grid <- grid_regular(penalty(range = c(-3, 1)), levels = 50)

# Now we have everything we need and we can fit all the models.
tune_res_lasso <- tune_grid(lasso_workflow, resamples = shrooms_folds_numeric, grid = penalty_grid)
## ! Fold1: internal: A correlation computation is required, but `estimate` is const...
## ! Fold2: internal: A correlation computation is required, but `estimate` is const...
## ! Fold3: internal: A correlation computation is required, but `estimate` is const...
## ! Fold4: internal: A correlation computation is required, but `estimate` is const...
## ! Fold5: internal: A correlation computation is required, but `estimate` is const...
save(tune_res_lasso, lasso_workflow, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/tune_res_lasso.rda")

# Use `autoplot()` to create a visualization:
autoplot(tune_res_lasso)

Once again, we will select the best penalty value by using the function ‘select_best()’ and specifying the metric = “rsq”.

best_penalty <- select_best(tune_res_lasso, metric = "rsq")

Create a new model with the optimal penalty value to refit using the entire training data set.

lasso_final <- finalize_workflow(lasso_workflow, best_penalty)
lasso_final_fit <- fit(lasso_final, data = shrooms_train_numeric)

And finally, we can gauge the performance of the final model by computing its R-squared on the training set.

lasso_metric <- augment(lasso_final_fit, new_data = shrooms_train_numeric) %>%
  rsq(truth = class, estimate = .pred)

Let’s create a table containing the estimates of each regression model.

rsq_ridge <- max(ridge_metric$.estimate)
rsq_lasso <- max(lasso_metric$.estimate)
rsq <- bind_cols(rsq_ridge, rsq_lasso)
colnames(rsq) <-  c('Ridge Regression', 'Lasso Regression')
rsq
## # A tibble: 1 × 2
##   `Ridge Regression` `Lasso Regression`
##                <dbl>              <dbl>
## 1              0.585              0.586

Of these models, the lasso regression model performed best, so I will finalize the lasso workflow with its best penalty and fit it to the entire training set.

best <- select_best(tune_res_lasso, metric = "rsq")
lasso_best <- finalize_workflow(lasso_workflow, best)
lasso_train_final_fit <- fit(lasso_best, data = shrooms_train_numeric)

And finally, I will predict on the testing set:

final_predict <- augment(lasso_train_final_fit, new_data = shrooms_test_numeric) %>%
  rsq(truth = class, estimate = .pred)
final_predict
## # A tibble: 1 × 3
##   .metric .estimator .estimate
##   <chr>   <chr>          <dbl>
## 1 rsq     standard       0.581

Classification Models

Logistic Regression Model

In order to create a logistic regression model, I will use logistic_reg() and the glm engine. I will create a workflow and add my model and the appropriate recipe. Finally, I will use fit_resamples() to fit the model to each of the folds.

I will set up a recipe to predict class with cap_shape, bruises, odor, gill_size, gill_color, stalk_root, stalk_color_above_ring, stalk_color_below_ring, spore_print_color, population and habitat.

In this recipe, I will:

  • Dummy-code all nominal predictors
  • Normalize all predictors
  • Center all predictors

shrooms_recipe <- recipe(class ~ cap_shape + bruises + odor + gill_size + gill_color + stalk_root + stalk_color_above_ring + stalk_color_below_ring + spore_print_color + population + habitat, data = shrooms_train) %>% 
  step_dummy(all_nominal_predictors()) %>% 
  step_normalize(all_predictors()) %>%
  step_center(all_predictors())

# specify the model type to be logistic regression and engine to be glm
log_reg <- logistic_reg() %>% 
  set_engine("glm") %>% 
  set_mode("classification")

# set up the workflow and fit the model to the training data
log_wkflow <- workflow() %>% 
  add_model(log_reg) %>% 
  add_recipe(shrooms_recipe)

log_fit <- fit_resamples(log_wkflow, shrooms_folds)

We can use collect_metrics() to print the mean and standard error of each performance metric across all folds.

log_metric <- collect_metrics(log_fit)

Random Forest Model

In order to set up a random forest model and workflow, I will use rand_forest() and the ranger engine, setting importance = "impurity". I will also tune mtry, trees, and min_n.

mtry refers to the number of predictors randomly sampled as candidates at each split, while trees refers to the number of trees in the forest. Lastly, min_n refers to the minimum number of data points a node must contain in order to be split further.

rf_model <- rand_forest(min_n = tune(), mtry = tune(), trees = tune(), mode = "classification") %>% 
  set_engine("ranger")

rf_workflow <- workflow() %>% 
  add_model(rf_model) %>% 
  add_recipe(shrooms_recipe)

# Create a regular grid with 6 levels each:
rf_grid <- grid_regular(min_n(range= c(1, 10)), mtry(range= c(1, 7)), trees(range= c(1, 4)), levels = 6)

# Specify `roc_auc` as a metric. Tune the model:
tune_res_rf = tune_grid(rf_workflow, resamples = shrooms_folds, grid = rf_grid, metrics = metric_set(roc_auc))

save(tune_res_rf, rf_workflow, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/tune_res_rf.rda")

# print an `autoplot()` of the result:
autoplot(tune_res_rf)

We will use collect_metrics() again to print the mean and standard errors of the performance metric, this time in descending order.

rf_metric <- arrange(collect_metrics(tune_res_rf), desc(mean))

Let’s create a variable importance plot using the function vip(), with our best-performing random forest model fit on the training set.

rf_spec_1 <- rand_forest(mtry = 7, trees = 3, min_n = 1) %>%
  set_engine("ranger", importance = 'impurity') %>%
  set_mode("classification")

rf_fit <- fit(rf_spec_1, class ~ cap_shape + bruises + odor + gill_size + gill_color + stalk_root + stalk_color_above_ring + stalk_color_below_ring + spore_print_color + population + habitat, data = shrooms_train_factor)

vip(rf_fit)

Clearly, the variable spore_print_color is the most important variable in predicting class with this model.

Boosted Trees

In order to set up a boosted tree model and workflow, I will use boost_tree() and the xgboost engine, tuning mtry, trees, and min_n.

boost_model <- boost_tree(min_n = tune(), mtry = tune(), trees = tune()) %>%
  set_engine("xgboost") %>%
  set_mode("classification")

boost_wf <- workflow() %>%
  add_recipe(shrooms_recipe) %>%
  add_model(boost_model)

# Create a regular grid with 6 levels each:
boost_grid <- grid_regular(min_n(range= c(1, 10)), mtry(range= c(1, 7)), trees(range= c(1, 4)), levels = 6)

# Specify `roc_auc` as a metric. Tune the model:
tune_res_boost <- tune_grid(boost_wf, resamples = shrooms_folds, grid = boost_grid, metrics = metric_set(roc_auc))

save(tune_res_boost, boost_wf, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/tune_res_boost.rda")

# print an `autoplot()` of the result:
autoplot(tune_res_boost)

Retrieve the mean and standard errors of the performance metric in descending order.

boost_metric <- arrange(collect_metrics(tune_res_boost), desc(mean))

K-Nearest Neighbor Model

In order to set up a K-Nearest Neighbor Model and workflow, I will use nearest_neighbor() and the kknn engine, tuning neighbors and setting the mode to “classification”.

knn_model <- nearest_neighbor(neighbors = tune(), mode = "classification") %>% 
  set_engine("kknn")

knn_workflow <- workflow() %>% 
  add_model(knn_model) %>% 
  add_recipe(shrooms_recipe)

# Create a tuning grid:
knn_params <- extract_parameter_set_dials(knn_model)

# Create a regular grid with 9 levels each:
knn_grid <- grid_regular(knn_params, levels = 9)

# Specify `roc_auc` as a metric. Tune the model:
tune_res_knn <- tune_grid(knn_workflow, resamples = shrooms_folds, grid = knn_grid, metrics = metric_set(roc_auc))

save(tune_res_knn, knn_workflow, file = "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/tune_res_knn.rda")

# print an `autoplot()` of the result:
autoplot(tune_res_knn)

Select the metrics in descending order.

knn_metric <- arrange(collect_metrics(tune_res_knn), desc(mean))

Let’s display the ROC AUC values for our best-performing logistic regression, random forest, boosted tree, and k-nearest neighbor models in a table.

roc_log <- max(log_metric$mean)
roc_rf <- max(rf_metric$mean)
roc_boost <- max(boost_metric$mean)
roc_knn <- max(knn_metric$mean)
roc <- bind_cols(roc_log, roc_rf, roc_boost, roc_knn)
colnames(roc) <-  c('Logistics Regression', 'Random Forest', 'Boosted', 'K-Nearest Neighbor')
roc
## # A tibble: 1 × 4
##   `Logistics Regression` `Random Forest` Boosted `K-Nearest Neighbor`
##                    <dbl>           <dbl>   <dbl>                <dbl>
## 1                      1               1   0.999                    1

The logistic regression, random forest, and k-nearest neighbor models performed equally well, each reaching a ROC AUC of 1.
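
For readers unfamiliar with the metric, ROC AUC can be read as the probability that a randomly chosen positive case receives a higher predicted score than a randomly chosen negative case. The sketch below computes it from ranks in base R, using made-up scores; the helper name auc_by_rank is hypothetical and is not the routine yardstick uses internally.

```r
# Rank-based AUC: equivalent to the share of (positive, negative) pairs
# in which the positive case has the higher score (ties count as half)
auc_by_rank <- function(scores, is_positive) {
  r     <- rank(scores)  # average ranks handle ties
  n_pos <- sum(is_positive)
  n_neg <- sum(!is_positive)
  (sum(r[is_positive]) - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg)
}

labels <- c(TRUE, TRUE, FALSE, FALSE)

# Every positive outranks every negative: AUC is 1
auc_perfect <- auc_by_rank(c(0.9, 0.8, 0.3, 0.1), labels)

# One positive scores below one negative: 3 of 4 pairs are ordered correctly
auc_partial <- auc_by_rank(c(0.9, 0.3, 0.8, 0.1), labels)

print(c(auc_perfect, auc_partial))
```

An AUC of 1 therefore means the model's scores separate the two classes perfectly on the resamples, whatever threshold is chosen.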

Final Model Building

For our final fit, I will use the Random Forest model. I will finalize the workflow by utilizing the parameters from the Random Forest model and fit the model to the entire training set.

best <- select_best(tune_res_rf, metric = "roc_auc")
rf_best <- finalize_workflow(rf_workflow, best)

# Use saved data set
shrooms_train_factor <- read_rds("/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_train_factor_processed.rds")

rf_train_final_fit <- fit(rf_best, data = shrooms_train_factor)

# Save final training set fit
write_rds(rf_train_final_fit, "/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Model Fitting/shrooms_train_factorc_processed.rds")

# Check the accuracy of the model
predict(rf_train_final_fit, new_data = shrooms_train_factor, type = 'class') %>%
  bind_cols(shrooms_train_factor %>% select(class)) %>%
  accuracy(truth = class, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary             1

Finally, I will use the model fit on the training set to predict on the testing set and check its accuracy.

shrooms_test_factor <- read_rds("/Users/michellesack/Desktop/UCSB/Senior yr/Spring/PSTAT 131/Processed/shrooms_test_factor_processed.rds")

predict(rf_train_final_fit, new_data = shrooms_test_factor, type = 'class') %>%
  bind_cols(shrooms_test_factor %>% select(class)) %>%
  accuracy(truth = class, estimate = .pred_class)
## # A tibble: 1 × 3
##   .metric  .estimator .estimate
##   <chr>    <chr>          <dbl>
## 1 accuracy binary         0.998
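
Accuracy treats both kinds of error alike, but for this problem calling a poisonous mushroom edible is far more costly than the reverse. A confusion matrix separates the two. The sketch below uses made-up labels, not the model's actual predictions, to show how the overall accuracy and the share of poisonous mushrooms missed can be computed in base R. The yardstick functions conf_mat() and sens() offer similar summaries.

```r
lvls <- c("Edible", "Poisonous")

# Made-up truth and predictions: one poisonous mushroom is called edible
truth <- factor(c("Poisonous", "Poisonous", "Poisonous",
                  "Edible", "Edible", "Edible"), levels = lvls)
pred  <- factor(c("Poisonous", "Poisonous", "Edible",
                  "Edible", "Edible", "Edible"), levels = lvls)

# Rows are the true class, columns the predicted class
conf_matrix <- table(truth = truth, prediction = pred)
print(conf_matrix)

# Overall accuracy: correct predictions over all predictions
acc <- sum(diag(conf_matrix)) / sum(conf_matrix)

# Share of truly poisonous mushrooms that were predicted edible
missed_poisonous <- conf_matrix["Poisonous", "Edible"] /
  sum(conf_matrix["Poisonous", ])

print(c(accuracy = acc, missed_poisonous = missed_poisonous))
```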

Let’s check a few predictions.

Prediction Number One

Poisonous Mushroom
prediction_1 <- data.frame(
  cap_shape = "Convex",
  bruises = "Bruises",
  odor = "Pungent",
  gill_size = "Narrow",
  gill_color = "Brown",
  stalk_root = "Equal",
  stalk_color_above_ring = "White",
  stalk_color_below_ring = "White",
  spore_print_color = "Black",
  population = "Scattered",
  habitat = "Urban")

predict(rf_train_final_fit, prediction_1)
## # A tibble: 1 × 1
##   .pred_class
##   <fct>      
## 1 Poisonous

The mushroom chosen for this prediction was poisonous and contained a convex cap shape, pungent odor, narrow gill size, brown gill color, equal stalk root, white stalk color, black spore print color, bruises, and was found in an urban environment among other scattered mushrooms. From these characteristics, the model correctly predicted that the mushroom was poisonous!
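
When building a new observation by hand like this, every value has to match a category the model saw during training, including capitalization. The following sketch shows one possible sanity check on made-up data: it flags any column of a new observation whose value is not among the training levels. The object names are hypothetical.

```r
# Toy training data with factor columns, standing in for shrooms_train_factor
train <- data.frame(
  odor    = factor(c("Almond", "Pungent", "None")),
  habitat = factor(c("Urban", "Grasses", "Woods"))
)

# A hand-built observation; "Meadow" was never seen in the toy training data
new_obs <- data.frame(odor = "Pungent", habitat = "Meadow")

# For each column, is the new value missing from the training levels?
unseen <- sapply(names(new_obs), function(col) {
  !(new_obs[[col]] %in% levels(train[[col]]))
})

# Columns that would need attention before calling predict()
unseen_cols <- names(unseen)[unseen]
print(unseen_cols)
```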

Prediction Number Two

Edible Mushroom
prediction_2 <- data.frame(
  cap_shape = "Convex",
  bruises = "Bruises",
  odor = "Almond",
  gill_size = "Broad",
  gill_color = "Brown",
  stalk_root = "Club",
  stalk_color_above_ring = "White",
  stalk_color_below_ring = "White",
  spore_print_color = "Black",
  population = "Numerous",
  habitat = "Grasses")

predict(rf_train_final_fit, prediction_2)
## # A tibble: 1 × 1
##   .pred_class
##   <fct>      
## 1 Edible

The mushroom chosen for this prediction was edible. It had a convex cap, an almond odor, broad brown gills, a club stalk root, a white stalk, a black spore print, and bruises, and it was found in a grass environment alongside numerous other mushrooms. From these characteristics, the model correctly predicted that the mushroom was edible!

Prediction Number Three

Poisonous Mushroom
prediction_3 <- data.frame(
  cap_shape = "Convex",
  bruises = "Bruises",
  odor = "Pungent",
  gill_size = "Narrow",
  gill_color = "Brown",
  stalk_root = "Equal",
  stalk_color_above_ring = "White",
  stalk_color_below_ring = "White",
  spore_print_color = "Black",
  population = "Scattered",
  habitat = "Grasses")

predict(rf_train_final_fit, prediction_3)
## # A tibble: 1 × 1
##   .pred_class
##   <fct>      
## 1 Poisonous

The mushroom chosen for this prediction was poisonous. It had a convex cap, a pungent odor, narrow brown gills, an equal stalk root, a white stalk, a black spore print, and bruises, and it was found in a grass environment among other scattered mushrooms. From these characteristics, the model correctly predicted that the mushroom was poisonous!

Prediction Number Four

Edible Mushroom
prediction_4 <- data.frame(
  cap_shape = "Convex",
  bruises = "Bruises",
  odor = "Almond",
  gill_size = "Broad",
  gill_color = "Brown",
  stalk_root = "Club",
  stalk_color_above_ring = "White",
  stalk_color_below_ring = "White",
  spore_print_color = "Black",
  population = "Scattered",
  habitat = "Meadows")

predict(rf_train_final_fit, prediction_4)
## # A tibble: 1 × 1
##   .pred_class
##   <fct>      
## 1 Edible

The mushroom chosen for this prediction was edible. It had a convex cap, an almond odor, broad brown gills, a club stalk root, a white stalk, a black spore print, and bruises, and it was found in a meadow among other scattered mushrooms. From these characteristics, the model correctly predicted that the mushroom was edible!
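The four checks above can also be run in one pass. The sketch below collects the same four sets of characteristics into a single table with an extra expected column and compares that column with the model output. The names candidates, expected, results, and correct are hypothetical, and the prediction step is guarded so that it only runs when the fitted workflow rf_train_final_fit exists in the session (with tidymodels loaded, as in the rest of this report).

```r
library(dplyr)

# The four mushrooms from the predictions above, one row each.
# `expected` is a hypothetical helper column holding the known class.
candidates <- tibble(
  expected               = c("Poisonous", "Edible", "Poisonous", "Edible"),
  cap_shape              = "Convex",
  bruises                = "Bruises",
  odor                   = c("Pungent", "Almond", "Pungent", "Almond"),
  gill_size              = c("Narrow", "Broad", "Narrow", "Broad"),
  gill_color             = "Brown",
  stalk_root             = c("Equal", "Club", "Equal", "Club"),
  stalk_color_above_ring = "White",
  stalk_color_below_ring = "White",
  spore_print_color      = "Black",
  population             = c("Scattered", "Numerous", "Scattered", "Scattered"),
  habitat                = c("Urban", "Grasses", "Grasses", "Meadows")
)

# Only predict when the fitted workflow from this report is available
if (exists("rf_train_final_fit")) {
  results <- candidates %>%
    bind_cols(
      predict(rf_train_final_fit, new_data = select(candidates, -expected))
    ) %>%
    mutate(correct = as.character(.pred_class) == expected)

  print(select(results, expected, .pred_class, correct))
}
```

Keeping the expected class next to each row makes it easy to add more hand-picked mushrooms later and see at a glance which ones the model gets wrong.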

Conclusion

The goal of this project was to create a machine learning model that is able to predict whether or not a mushroom is poisonous based on its characteristics. The process was split up into several parts. First, I downloaded the “Mushroom Classification Updated Dataset” from Kaggle, imported it into R, and loaded the necessary packages. I then cleaned up the data set using the clean_names() function and by removing certain unwanted variables. Once the data was clean, I performed a data split, selecting 80% for the training set and the remaining 20% for the testing set. I then began my exploratory data analysis by inspecting a number of different charts relating to the distribution of poisonous mushrooms among each of the possible variables. From this analysis, I found that the odor and spore print color of a mushroom are the strongest indications of whether or not it is poisonous!

The next step in my report was building the model. I decided to test 2 regression models and 4 classification models. For the regression models, I chose a ridge regression and a lasso regression, and for the classification models, I chose a logistic regression, random forest, boosted tree, and k-nearest neighbors. I used v-fold cross validation on the training set, using 5 folds and stratifying on the outcome variable class. I built a recipe for both the regression models and the classification models and tuned the necessary parameters. Once each model was fit and tested, I compared their accuracies to determine the best fit model. Surprisingly, the majority of the models achieved extremely high accuracy.

For my final model fit, I decided to use the random forest model. I finalized the workflow using finalize_workflow() and fit the model to the entire training set. I then fit the model to the testing set and checked its accuracy. To my surprise, the model was 99.8% accurate on the testing set, after reaching 100% accuracy on the training set! This is likely because several variables in the data set almost completely determine a mushroom’s class. For example, in this data set a mushroom with a foul odor is always poisonous. I also decided to conduct four separate tests where I provided the characteristics of a mushroom and tested if the model could correctly predict its class. The model determined the correct class for all four of these tests!

Now that I have built an effective model, if you ever end up on that remote island with only mushrooms as your source of food, feel free to use this model to predict which of those mushrooms are safe to eat!